share apply_strategy method between autounit and autopredictunit #512

Open

JKSenthil wants to merge 1 commit into master

Conversation

JKSenthil
Contributor

Summary:

## Context:

Both `AutoUnit` and `AutoPredictUnit` use the same code block to apply the strategy to the module and check for any incompatibilities:

```
if strategy:
    if isinstance(strategy, str):
        strategy = _convert_str_to_strategy(strategy)
    if isinstance(strategy, DDPStrategy):
        if torch_compile_params and strategy.static_graph is True:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            raise RuntimeError(
                "Torch compile requires DDPStrategy's static_graph to be False"
            )
        module = prepare_ddp(module, self.device, strategy)
    elif isinstance(strategy, FSDPStrategy):
        if swa_params:
            raise RuntimeError(
                "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
            )
        # as stated here https://pytorch.org/get-started/pytorch-2.0/
        rank_zero_warn(
            "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
        )
        module = prepare_fsdp(module, self.device, strategy)
else:
    module = module.to(self.device)
```

If changes are made to this logic, they must be made in both classes, which is easy to miss.

## This Diff

Creates a helper function `_apply_strategy_and_check(...)` that applies the strategy to the module, and calls this function in both `AutoUnit` and `AutoPredictUnit` (other name suggestions are also welcome).
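
For reference, a minimal sketch of what such a helper could look like. The name comes from the diff summary, but the signature, type annotations, and parameter order are assumptions (the diff body is not shown here); it is assumed to live in torchtnt/framework/auto_unit.py, where the referenced helpers and strategy classes are already in scope:

```
from typing import Optional, Union

import torch

# NOTE: a sketch only -- the real signature in D48612629 may differ.
# Assumes this lives in torchtnt/framework/auto_unit.py, where
# _convert_str_to_strategy, prepare_ddp, prepare_fsdp, rank_zero_warn,
# DDPStrategy, FSDPStrategy, TorchCompileParams, and SWAParams are
# already defined/imported.
def _apply_strategy_and_check(
    module: torch.nn.Module,
    device: torch.device,
    strategy: Optional[Union[str, DDPStrategy, FSDPStrategy]],
    torch_compile_params: Optional[TorchCompileParams],
    swa_params: Optional[SWAParams],
) -> torch.nn.Module:
    # Same logic as the duplicated block above, with self.device
    # replaced by an explicit device argument.
    if strategy:
        if isinstance(strategy, str):
            strategy = _convert_str_to_strategy(strategy)
        if isinstance(strategy, DDPStrategy):
            if torch_compile_params and strategy.static_graph is True:
                # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                raise RuntimeError(
                    "Torch compile requires DDPStrategy's static_graph to be False"
                )
            module = prepare_ddp(module, device, strategy)
        elif isinstance(strategy, FSDPStrategy):
            if swa_params:
                raise RuntimeError(
                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                )
            # as stated here https://pytorch.org/get-started/pytorch-2.0/
            rank_zero_warn(
                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
            )
            module = prepare_fsdp(module, device, strategy)
    else:
        module = module.to(device)
    return module
```

Each constructor would then collapse the duplicated block into a single call, e.g. `module = _apply_strategy_and_check(module, self.device, strategy, torch_compile_params, swa_params)`.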

Differential Revision: D48612629

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D48612629

@codecov

codecov bot commented Aug 23, 2023

Codecov Report

Merging #512 (3c1ecc5) into master (43555dd) will increase coverage by 0.10%.
The diff coverage is 64.70%.

```
@@            Coverage Diff             @@
##           master     #512      +/-   ##
==========================================
+ Coverage   87.19%   87.30%   +0.10%
==========================================
  Files         106      106
  Lines        8411     8403       -8
==========================================
+ Hits         7334     7336       +2
+ Misses       1077     1067      -10
```

| Files Changed | Coverage Δ |
| --- | --- |
| torchtnt/framework/auto_unit.py | 82.41% <64.70%> (+2.72%) ⬆️ |

